package main

import (
	"bytes"
	"crypto"
	"crypto/aes"
	"crypto/cipher"
	"crypto/ecdh"
	"crypto/ed25519"
	"crypto/rand"
	"crypto/rsa"
	"crypto/sha256"
	"crypto/sha512"
	"crypto/x509"
	"encoding/asn1"
	"encoding/base64"
	"encoding/pem"
	"errors"
	"flag"
	"fmt"
	"io"
	logger "log"
	"math/big"
	"strings"

	"golang.org/x/crypto/ssh"

	"os"
)

// log is the package-wide logger. It writes to standard error with no prefix
// and stamps every line with the UTC date and time (microsecond resolution)
// plus the short file name and line number of the call site. Because this
// variable is named log, the standard library package is imported under the
// alias logger.
var log = logger.New(
	os.Stderr,
	"",
	logger.Ldate|logger.Ltime|logger.Lmicroseconds|logger.LUTC|logger.Lshortfile,
)

// main is the CLI entry point. It parses the flags and then runs either the
// Ed25519 flow (X25519 + AES-GCM) or the RSA flow.
//
// Flags:
//   - -text / -t: plaintext to encrypt
//   - -encrypt / -e: ciphertext to decrypt
//   - -publicKey / -p: public key to encrypt with; defaults to the local key
//     under ~/.ssh
//   - -sha256: RSA only, use OAEP-SHA256 instead of chunked PKCS#1 v1.5
//     (default true)
//   - -ed: use the Ed25519 flow instead of the RSA flow (default true)
//
// When neither a plaintext nor a ciphertext is supplied the flag defaults are
// printed and the program exits.
func main() {
	// Example: go run main.go -text=111
	publicKey := flag.String("publicKey", "", "-publicKey='加密使用的公钥' : 默认是本机公钥，发信息给对方，就需要对方的公钥")
	rawText := flag.String("text", "", "-text='待加密文本' : 待加密文本")
	encryptText := flag.String("encrypt", "", "-encrypt='待解密文本' : 待解密文本")

	// Short aliases for the flags above.
	rawTextShort := flag.String("t", "", "-t 对应 -text")
	encryptTextShort := flag.String("e", "", "-e 对应 -encrypt")
	publicKeyShort := flag.String("p", "", "-p 对应 -publicKey")
	// Named useSHA256 (not sha256) so the crypto/sha256 package is not shadowed.
	useSHA256 := flag.Bool("sha256", true, "默认使用sha256")
	useEd25519 := flag.Bool("ed", true, "默认使用ed25519密钥")

	flag.Parse()

	// A non-empty short alias overrides its long form.
	if *encryptTextShort != "" {
		*encryptText = *encryptTextShort
	}
	if *rawTextShort != "" {
		*rawText = *rawTextShort
	}
	if *publicKeyShort != "" {
		*publicKey = *publicKeyShort
	}

	if len(*rawText) == 0 && len(*encryptText) == 0 {
		flag.PrintDefaults()
		return
	}

	home, err := os.UserHomeDir()
	if err != nil {
		log.Fatal(err)
	}

	if *useEd25519 {
		runEd25519(home, *publicKey, *rawText, *encryptText)
		return
	}
	runRSA(home, *publicKey, *rawText, *encryptText, *useSHA256)
}

// runEd25519 encrypts rawText (if any) with the given OpenSSH Ed25519 public
// key, or with ~/.ssh/id_ed25519.pub when publicKey is empty, and then decrypts
// either encryptText or the ciphertext it just produced using
// ~/.ssh/id_ed25519. Results and errors are written to the log.
func runEd25519(home, publicKey, rawText, encryptText string) {
	var publicKeyBytes []byte
	if len(publicKey) > 0 {
		publicKeyBytes = []byte(publicKey)
	} else {
		// Default to the local Ed25519 public key.
		var err error
		publicKeyBytes, err = os.ReadFile(home + "/.ssh/id_ed25519.pub")
		if err != nil {
			log.Fatalf("Error reading ed25519 public key file: %v", err)
		}
	}

	// The key is expected in OpenSSH authorized_keys format.
	parsedPub, _, _, _, err := ssh.ParseAuthorizedKey(publicKeyBytes)
	if err != nil {
		log.Fatalf("Error parsing ed25519 public key: %v", err)
	}
	// Checked assertions: a key type that does not expose a crypto public key
	// must produce a clean error rather than a panic.
	var edPub ed25519.PublicKey
	if cryptoPub, ok := parsedPub.(ssh.CryptoPublicKey); ok {
		edPub, _ = cryptoPub.CryptoPublicKey().(ed25519.PublicKey)
	}
	if edPub == nil {
		log.Fatalf("public key is not an ED25519 public key")
	}

	var encrypted string
	if len(rawText) > 0 {
		enc, err := Ed25519Encrypt([]byte(rawText), edPub)
		if err != nil {
			log.Println("加密错误", err)
		} else {
			encrypted = enc
			log.Println("加密后\n" + encrypted)
		}
	}

	// An explicitly supplied ciphertext takes precedence over the fresh one.
	if len(encryptText) > 0 {
		encrypted = encryptText
	}
	if len(encrypted) == 0 {
		return
	}

	// Load the local Ed25519 private key (OpenSSH private key format).
	privateKeyPath := home + "/.ssh/id_ed25519"
	var privateKey []byte
	if info, statErr := os.Stat(privateKeyPath); statErr == nil && !info.IsDir() {
		data, readErr := os.ReadFile(privateKeyPath)
		if readErr != nil {
			log.Print(readErr)
		} else {
			privateKey = data
		}
	}
	if len(privateKey) == 0 {
		log.Println("未找到 Ed25519 私钥，无法解密")
		return
	}

	key, err := ssh.ParseRawPrivateKey(privateKey)
	if err != nil {
		log.Println("解析 Ed25519 私钥失败", err)
		return
	}

	var edPriv ed25519.PrivateKey
	switch k := key.(type) {
	case ed25519.PrivateKey:
		edPriv = k
	case *ed25519.PrivateKey:
		edPriv = *k
	default:
		log.Println("private key is not an ED25519 private key")
		return
	}

	plain, err := Ed25519Decrypt(encrypted, edPriv)
	if err != nil {
		log.Println("解密错误", err)
		return
	}
	log.Println("解密后\n" + string(plain))
}

// runRSA encrypts rawText (if any) with the given RSA public key, or with
// ~/.ssh/id_rsa.pub when publicKey is empty, and then decrypts either
// encryptText or the ciphertext it just produced using ~/.ssh/id_rsa.
// useSHA256 selects OAEP-SHA256; otherwise chunked PKCS#1 v1.5 is used.
// Results and errors are written to the log.
func runRSA(home, publicKey, rawText, encryptText string, useSHA256 bool) {
	var publicKeyBytes []byte
	if len(publicKey) > 0 {
		publicKeyBytes = []byte(publicKey)
	} else {
		var err error
		publicKeyBytes, err = os.ReadFile(home + "/.ssh/id_rsa.pub")
		if err != nil {
			log.Fatalf("Error reading public key file: %v", err)
		}
	}

	// Drop tab characters (byte 9); they break decoding of the key text.
	publicKeyBytes = bytes.ReplaceAll(publicKeyBytes, []byte{'\t'}, nil)

	log.Println("当前公钥 " + string(publicKeyBytes))

	var rsaPublicKey *rsa.PublicKey
	if bytes.HasPrefix(publicKeyBytes, []byte("ssh-rsa")) {
		parsedKey, _, _, _, err := ssh.ParseAuthorizedKey(publicKeyBytes)
		if err != nil {
			log.Fatalf("Error parsing public key: %v", err)
		}
		if cryptoKey, ok := parsedKey.(ssh.CryptoPublicKey); ok {
			rsaPublicKey, _ = cryptoKey.CryptoPublicKey().(*rsa.PublicKey)
		}
		if rsaPublicKey == nil {
			log.Fatalf("Error converting to *rsa.PublicKey")
		}
	} else {
		// TODO: PEM text whose line structure is damaged is not handled yet;
		// running it through formatPEM first is still to be debugged.
		block, _ := pem.Decode(publicKeyBytes)
		if block == nil {
			log.Fatal("public key error: no PEM block found")
		}
		switch block.Type {
		case "RSA PUBLIC KEY":
			parsed, err := x509.ParsePKCS1PublicKey(block.Bytes)
			if err != nil {
				log.Fatalf("Error parsing public key by ParsePKCS1PublicKey: %v", err)
			}
			rsaPublicKey = parsed
		default:
			parsed, err := x509.ParsePKIXPublicKey(block.Bytes)
			if err != nil {
				log.Fatalf("Error parsing public key by ParsePKIXPublicKey: %v", err)
			}
			var ok bool
			rsaPublicKey, ok = parsed.(*rsa.PublicKey)
			if !ok {
				// Report the actual key type; there is no error value here.
				log.Fatalf("public key not supported %T", parsed)
			}
		}
		log.Println("公钥长度", rsaPublicKey.Size()*8)
	}

	// Normalise the public key to a PKCS#1 PEM block for NewXRsa.
	publicKeyModel := pem.EncodeToMemory(&pem.Block{
		Type:  "RSA PUBLIC KEY",
		Bytes: x509.MarshalPKCS1PublicKey(rsaPublicKey),
	})

	// The private key is optional: a missing file is silently skipped, any
	// other read failure is logged.
	privateKeyPath := home + "/.ssh/id_rsa"
	var privateKey []byte
	if _, statErr := os.Stat(privateKeyPath); !os.IsNotExist(statErr) {
		data, readErr := os.ReadFile(privateKeyPath)
		if readErr != nil {
			log.Print(readErr)
		} else {
			privateKey = data
		}
	}

	xRsa, err := NewXRsa(publicKeyModel, privateKey)
	if err != nil {
		// Continuing would dereference a nil *XRsa below.
		log.Fatal(err)
	}

	var encrypted string
	if len(rawText) > 0 {
		if useSHA256 {
			encrypted, err = xRsa.PublicEncrypt256(rawText)
		} else {
			encrypted, err = xRsa.PublicEncrypt(rawText)
		}
		if err != nil {
			log.Println("加密错误", err)
		} else {
			log.Println("加密后\n" + encrypted)
		}
	}

	// An explicitly supplied ciphertext takes precedence over the fresh one.
	if len(encryptText) > 0 {
		encrypted = encryptText
	}
	if len(encrypted) == 0 {
		return
	}

	var decrypted string
	if useSHA256 {
		decrypted, err = xRsa.PrivateDecrypt256(encrypted)
	} else {
		decrypted, err = xRsa.PrivateDecrypt(encrypted)
	}
	if err != nil {
		log.Println("解密错误", err)
		return
	}
	log.Println("解密后\n" + decrypted)
}

// cleanPEM trims leading and trailing whitespace from pem and then deletes
// every line-feed and carriage-return character, collapsing the text onto a
// single line.
func cleanPEM(pem string) string {
	trimmed := strings.TrimSpace(pem)
	return strings.NewReplacer("\n", "", "\r", "").Replace(trimmed)
}

// formatPEM normalises a key into a PEM block of type pemType.
//
// The input may be bare Base64 or an existing PEM block whose line structure
// is damaged (for example collapsed onto one line). Any BEGIN/END markers for
// pemType are stripped, all whitespace is removed from the body, the body is
// validated as standard padded Base64, and the result is re-wrapped at 64
// characters per line. The returned text has no trailing newline after the
// END marker.
//
// An error wrapping the Base64 decoding error is returned when the body is
// not valid Base64.
func formatPEM(input, pemType string) (string, error) {
	beginMarker := fmt.Sprintf("-----BEGIN %s-----", pemType)
	endMarker := fmt.Sprintf("-----END %s-----", pemType)

	// Note: there is deliberately no "already PEM" shortcut. cleanPEM removes
	// every newline, so a check that requires one (isPEM) could never succeed
	// here; the input is always re-wrapped instead.
	body := cleanPEM(input)
	body = strings.ReplaceAll(body, beginMarker, "")
	body = strings.ReplaceAll(body, endMarker, "")
	// Base64 never contains whitespace, so stray spaces or tabs left between
	// the markers and the data (or inside the data) can be dropped safely.
	body = strings.Join(strings.Fields(body), "")

	if err := validateBase64(body); err != nil {
		return "", fmt.Errorf("invalid Base64 input: %w", err)
	}

	const lineLength = 64
	var builder strings.Builder
	builder.WriteString(beginMarker + "\n")
	for i := 0; i < len(body); i += lineLength {
		end := i + lineLength
		if end > len(body) {
			end = len(body)
		}
		builder.WriteString(body[i:end] + "\n")
	}
	builder.WriteString(endMarker)
	return builder.String(), nil
}

// isPEM reports whether input starts with the BEGIN marker for pemType, ends
// with the matching END marker, and contains at least one newline in between.
func isPEM(input, pemType string) bool {
	header := "-----BEGIN " + pemType + "-----"
	footer := "-----END " + pemType + "-----"

	if !strings.HasPrefix(input, header) {
		return false
	}
	if !strings.Contains(input, "\n") {
		return false
	}
	return strings.HasSuffix(input, footer)
}

// validateBase64 returns nil when input decodes as standard (padded) Base64
// and the decoding error otherwise. The decoded bytes are discarded.
func validateBase64(input string) error {
	if _, err := base64.StdEncoding.DecodeString(input); err != nil {
		return err
	}
	return nil
}

// Constants describing the conventions of the XRsa helper. Of these, only
// RSA_ALGORITHM_SIGN is referenced by the code visible in this file; the three
// string constants appear to be descriptive labels only.
const (
	// CHAR_SET names the text encoding assumed for plaintext.
	CHAR_SET               = "UTF-8"
	// BASE_64_FORMAT names the Base64 flavour (URL-safe, unpadded).
	// NOTE(review): PublicEncrypt/PrivateDecrypt/Sign/Verify use
	// base64.RawURLEncoding, but the *256 variants use base64.RawStdEncoding.
	BASE_64_FORMAT         = "UrlSafeNoPadding"
	// RSA_ALGORITHM_KEY_TYPE names the private key container format.
	RSA_ALGORITHM_KEY_TYPE = "PKCS8"
	// RSA_ALGORITHM_SIGN is the hash used by Sign and Verify.
	RSA_ALGORITHM_SIGN     = crypto.SHA256
)

// XRsa bundles an RSA public key and an RSA private key and offers
// encrypt/decrypt and sign/verify helpers on top of them. Either field may be
// nil when NewXRsa was given no key material for it.
type XRsa struct {
	// publicKey is used for encryption and signature verification.
	publicKey  *rsa.PublicKey
	// privateKey is used for decryption and signing.
	privateKey *rsa.PrivateKey
}

// CreateKeys generates a new RSA key pair of keyLength bits and writes both
// halves as PEM.
//
// The private key is written to privateKeyWriter as a PKCS#8 "PRIVATE KEY"
// block and the public key to publicKeyWriter as a PKIX "PUBLIC KEY" block.
// The private key is written first. The first error from key generation,
// marshalling or writing is returned.
func CreateKeys(publicKeyWriter, privateKeyWriter io.Writer, keyLength int) error {
	privateKey, err := rsa.GenerateKey(rand.Reader, keyLength)
	if err != nil {
		return err
	}

	// PKCS#8 is more widely accepted than PKCS#1. The standard library encoder
	// is used so that the AlgorithmIdentifier carries the NULL parameters
	// required for rsaEncryption and marshalling errors are not dropped.
	derPKCS8, err := x509.MarshalPKCS8PrivateKey(privateKey)
	if err != nil {
		return err
	}
	if err := pem.Encode(privateKeyWriter, &pem.Block{
		Type:  "PRIVATE KEY",
		Bytes: derPKCS8,
	}); err != nil {
		return err
	}

	derPKIX, err := x509.MarshalPKIXPublicKey(&privateKey.PublicKey)
	if err != nil {
		return err
	}
	return pem.Encode(publicKeyWriter, &pem.Block{
		Type:  "PUBLIC KEY",
		Bytes: derPKIX,
	})
}
// NewXRsa builds an XRsa from PEM-encoded key material.
//
// publicKey may be a PKCS#1 "RSA PUBLIC KEY" block or, for any other block
// type, a PKIX public key. privateKey may be an "OPENSSH PRIVATE KEY" block, a
// PKCS#1 "RSA PRIVATE KEY" block or, for any other block type, a PKCS#8
// private key. Passing nil for either argument leaves the corresponding key
// unset. The size of every key that is loaded is written to the log.
//
// An error is returned when a supplied key contains no PEM block, cannot be
// parsed, or is not an RSA key.
func NewXRsa(publicKey []byte, privateKey []byte) (*XRsa, error) {
	var pub *rsa.PublicKey
	var priv *rsa.PrivateKey

	if publicKey != nil {
		block, _ := pem.Decode(publicKey)
		if block == nil {
			return nil, errors.New("public key error")
		}
		switch block.Type {
		case "RSA PUBLIC KEY":
			parsed, err := x509.ParsePKCS1PublicKey(block.Bytes)
			if err != nil {
				return nil, err
			}
			pub = parsed
		default:
			parsed, err := x509.ParsePKIXPublicKey(block.Bytes)
			if err != nil {
				return nil, err
			}
			rsaPub, ok := parsed.(*rsa.PublicKey)
			if !ok {
				return nil, errors.New("public key not supported")
			}
			pub = rsaPub
		}

		log.Println("公钥长度", pub.Size()*8)
	}

	if privateKey != nil {
		block, _ := pem.Decode(privateKey)
		if block == nil {
			return nil, errors.New("private key error!")
		}

		switch block.Type {
		case "OPENSSH PRIVATE KEY":
			log.Println("Here at OPENSSH Private Key:")
			key, err := ssh.ParseRawPrivateKey(privateKey)
			if err != nil {
				// Return instead of logging and carrying on: priv would stay
				// nil and the size log below would dereference it.
				return nil, fmt.Errorf("parsing OpenSSH private key: %w", err)
			}
			rsaPriv, ok := key.(*rsa.PrivateKey)
			if !ok {
				return nil, errors.New("private key not supported")
			}
			priv = rsaPriv
		case "RSA PRIVATE KEY":
			parsed, err := x509.ParsePKCS1PrivateKey(block.Bytes)
			if err != nil {
				return nil, err
			}
			priv = parsed
		default:
			parsed, err := x509.ParsePKCS8PrivateKey(block.Bytes)
			if err != nil {
				return nil, err
			}
			rsaPriv, ok := parsed.(*rsa.PrivateKey)
			if !ok {
				return nil, errors.New("private key not supported")
			}
			priv = rsaPriv
		}
		log.Println("私钥长度", priv.Size()*8)
	}

	return &XRsa{
		publicKey:  pub,
		privateKey: priv,
	}, nil
}

// PublicEncrypt256 encrypts data with the public key using RSA-OAEP with
// SHA-256 and an empty label, and returns the ciphertext as unpadded standard
// Base64.
//
// The whole message is encrypted as a single block, so the underlying
// rsa.EncryptOAEP call rejects data that is too long for the key. An error is
// returned when no public key is loaded or when encryption fails.
func (r *XRsa) PublicEncrypt256(data string) (string, error) {
	if r == nil || r.publicKey == nil {
		return "", errors.New("public key not loaded")
	}
	ciphertext, err := rsa.EncryptOAEP(sha256.New(), rand.Reader, r.publicKey, []byte(data), nil)
	if err != nil {
		// Propagate the failure; returning "" with a nil error would look
		// like a successful encryption to the caller.
		return "", err
	}
	return base64.RawStdEncoding.EncodeToString(ciphertext), nil
}

// PrivateDecrypt256 decodes encrypted from unpadded standard Base64 and
// decrypts it with the private key using RSA-OAEP with SHA-256 and an empty
// label (the label must match the one used for encryption).
//
// An error is returned when no private key is loaded, when the input is not
// valid Base64, or when decryption fails.
func (r *XRsa) PrivateDecrypt256(encrypted string) (string, error) {
	if r == nil || r.privateKey == nil {
		// rsa.DecryptOAEP would dereference the nil key and panic.
		return "", errors.New("private key not loaded")
	}
	ciphertext, err := base64.RawStdEncoding.DecodeString(encrypted)
	if err != nil {
		return "", err
	}
	plaintext, err := rsa.DecryptOAEP(sha256.New(), rand.Reader, r.privateKey, ciphertext, nil)
	if err != nil {
		return "", err
	}
	return string(plaintext), nil
}

// PublicEncrypt encrypts data with the public key using RSA PKCS#1 v1.5.
//
// The plaintext is cut into chunks of at most (modulus size - 11) bytes, each
// chunk is encrypted separately, and the concatenated ciphertext blocks are
// returned as unpadded URL-safe Base64. Empty data yields an empty string.
//
// An error is returned when no public key is loaded, when the key is too
// small to carry any payload, or when encrypting a chunk fails.
func (r *XRsa) PublicEncrypt(data string) (string, error) {
	if r == nil || r.publicKey == nil {
		return "", errors.New("public key not loaded")
	}
	// PKCS#1 v1.5 needs 11 bytes of padding per block.
	partLen := r.publicKey.Size() - 11
	if partLen <= 0 {
		return "", errors.New("public key too small for PKCS#1 v1.5 encryption")
	}

	var buffer bytes.Buffer
	for _, chunk := range split([]byte(data), partLen) {
		encrypted, err := rsa.EncryptPKCS1v15(rand.Reader, r.publicKey, chunk)
		if err != nil {
			return "", err
		}
		buffer.Write(encrypted)
	}
	return base64.RawURLEncoding.EncodeToString(buffer.Bytes()), nil
}

// PrivateDecrypt reverses PublicEncrypt: it decodes encrypted from unpadded
// URL-safe Base64, cuts the result into blocks of the private key's modulus
// size, decrypts each block with RSA PKCS#1 v1.5 and returns the concatenated
// plaintext.
//
// The block size is taken from the private key, which is the key the
// ciphertext must have been produced for; the loaded public key may be absent
// or belong to someone else. An error is returned when no private key is
// loaded, when the input is not valid Base64, or when a block fails to
// decrypt.
func (r *XRsa) PrivateDecrypt(encrypted string) (string, error) {
	if r == nil || r.privateKey == nil {
		return "", errors.New("private key not loaded")
	}
	raw, err := base64.RawURLEncoding.DecodeString(encrypted)
	if err != nil {
		// DecodeString may return partially decoded bytes alongside the
		// error; do not try to decrypt them.
		return "", err
	}

	partLen := r.privateKey.Size()
	var buffer bytes.Buffer
	for _, chunk := range split(raw, partLen) {
		decrypted, err := rsa.DecryptPKCS1v15(rand.Reader, r.privateKey, chunk)
		if err != nil {
			return "", err
		}
		buffer.Write(decrypted)
	}
	return buffer.String(), nil
}

// Sign hashes data with RSA_ALGORITHM_SIGN, signs the digest with the private
// key using RSA PKCS#1 v1.5, and returns the signature as unpadded URL-safe
// Base64.
func (r *XRsa) Sign(data string) (string, error) {
	hasher := RSA_ALGORITHM_SIGN.New()
	hasher.Write([]byte(data))
	digest := hasher.Sum(nil)

	signature, err := rsa.SignPKCS1v15(rand.Reader, r.privateKey, RSA_ALGORITHM_SIGN, digest)
	if err != nil {
		return "", err
	}
	return base64.RawURLEncoding.EncodeToString(signature), nil
}

// Verify checks that sign (unpadded URL-safe Base64) is a valid RSA PKCS#1
// v1.5 signature over the RSA_ALGORITHM_SIGN hash of data, using the public
// key. It returns nil on success, the Base64 decoding error when sign cannot
// be decoded, or the verification error otherwise.
func (r *XRsa) Verify(data string, sign string) error {
	signature, err := base64.RawURLEncoding.DecodeString(sign)
	if err != nil {
		return err
	}

	hasher := RSA_ALGORITHM_SIGN.New()
	hasher.Write([]byte(data))
	digest := hasher.Sum(nil)

	return rsa.VerifyPKCS1v15(r.publicKey, RSA_ALGORITHM_SIGN, digest, signature)
}
// MarshalPKCS8PrivateKey encodes key as a DER PKCS#8 PrivateKeyInfo structure
// wrapping the PKCS#1 encoding of the key.
//
// The AlgorithmIdentifier is rsaEncryption (1.2.840.113549.1.1.1) with
// explicit NULL parameters, which is what the RSA standards require and what
// x509.MarshalPKCS8PrivateKey emits. It returns nil if ASN.1 marshalling
// fails.
func MarshalPKCS8PrivateKey(key *rsa.PrivateKey) []byte {
	// algorithmIdentifier mirrors the ASN.1 AlgorithmIdentifier SEQUENCE.
	type algorithmIdentifier struct {
		Algorithm  asn1.ObjectIdentifier
		Parameters asn1.RawValue
	}
	info := struct {
		Version             int
		PrivateKeyAlgorithm algorithmIdentifier
		PrivateKey          []byte
	}{
		Version: 0,
		PrivateKeyAlgorithm: algorithmIdentifier{
			Algorithm:  asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 1, 1},
			Parameters: asn1.NullRawValue,
		},
		PrivateKey: x509.MarshalPKCS1PrivateKey(key),
	}

	der, err := asn1.Marshal(info)
	if err != nil {
		// The signature has no error result; nil signals failure.
		return nil
	}
	return der
}
// split cuts buf into consecutive chunks of lim bytes; a shorter final chunk
// holds whatever is left over. The chunks are sub-slices of buf, not copies.
// lim must be positive.
func split(buf []byte, lim int) [][]byte {
	chunks := make([][]byte, 0, len(buf)/lim+1)
	for ; len(buf) >= lim; buf = buf[lim:] {
		chunks = append(chunks, buf[:lim])
	}
	if len(buf) > 0 {
		chunks = append(chunks, buf)
	}
	return chunks
}

// ===== Encryption and decryption with Ed25519 keys via X25519 + AES-GCM =====

// EdPubToCurvePub converts an Ed25519 public key into the corresponding
// X25519 public key (see RFC 7748).
//
// The 32-byte input is read as the little-endian Edwards y coordinate with
// the sign bit cleared, and the Montgomery u coordinate is computed as
// u = (1 + y) / (1 - y) mod p, with p = 2^255 - 19. The result is returned in
// little-endian byte order.
//
// An error is returned when edPub is not 32 bytes long or when 1 - y has no
// inverse modulo p.
func EdPubToCurvePub(edPub ed25519.PublicKey) ([32]byte, error) {
	var out [32]byte
	if len(edPub) != 32 {
		return out, errors.New("invalid ed25519 public key")
	}

	// Build a big-endian copy of y for math/big and drop the sign bit, which
	// ends up in the most significant byte.
	var yBE [32]byte
	for i, b := range edPub {
		yBE[31-i] = b
	}
	yBE[0] &= 0x7F

	// p = 2^255 - 19
	p := new(big.Int).Lsh(big.NewInt(1), 255)
	p.Sub(p, big.NewInt(19))

	one := big.NewInt(1)
	y := new(big.Int).SetBytes(yBE[:])

	denom := new(big.Int).Sub(one, y)
	denom.Mod(denom, p)
	denomInv := new(big.Int).ModInverse(denom, p)
	if denomInv == nil {
		return out, errors.New("no modular inverse")
	}

	u := new(big.Int).Add(one, y)
	u.Mul(u, denomInv)
	u.Mod(u, p)

	// u < p < 2^255, so it always fits in 32 bytes. FillBytes writes it
	// big-endian and zero-padded; reverse it into little-endian order.
	var uBE [32]byte
	u.FillBytes(uBE[:])
	for i, b := range uBE {
		out[31-i] = b
	}
	return out, nil
}

// EdPrivToCurvePriv derives an X25519 private key from an Ed25519 private
// key: it returns the first 32 bytes of the SHA-512 hash of the 32-byte seed
// stored at the start of edPriv.
//
// An error is returned when edPriv is not 64 bytes long.
func EdPrivToCurvePriv(edPriv ed25519.PrivateKey) ([32]byte, error) {
	var xPriv [32]byte
	if len(edPriv) != ed25519.PrivateKeySize {
		return xPriv, errors.New("invalid ed25519 private key")
	}
	digest := sha512.Sum512(edPriv[:ed25519.SeedSize])
	copy(xPriv[:], digest[:len(xPriv)])
	return xPriv, nil
}

// Ed25519Encrypt encrypts message for the holder of the Ed25519 key
// recipientPub.
//
// A fresh X25519 key pair is generated, the recipient key is converted with
// EdPubToCurvePub, and the X25519 shared secret is hashed with SHA-256 to
// obtain an AES-256 key. The message is sealed with AES-GCM under a random
// nonce and no additional data. The result is the standard Base64 encoding of
// ephemeral public key || nonce || ciphertext.
func Ed25519Encrypt(message []byte, recipientPub ed25519.PublicKey) (string, error) {
	x25519 := ecdh.X25519()

	ephemeral, err := x25519.GenerateKey(rand.Reader)
	if err != nil {
		return "", err
	}
	ephemeralPub := ephemeral.PublicKey().Bytes()

	recipientU, err := EdPubToCurvePub(recipientPub)
	if err != nil {
		return "", err
	}
	recipientKey, err := x25519.NewPublicKey(recipientU[:])
	if err != nil {
		return "", err
	}

	secret, err := ephemeral.ECDH(recipientKey)
	if err != nil {
		return "", err
	}

	// AES-256 key = SHA-256(shared secret).
	aesKey := sha256.Sum256(secret)
	blockCipher, err := aes.NewCipher(aesKey[:])
	if err != nil {
		return "", err
	}
	gcm, err := cipher.NewGCM(blockCipher)
	if err != nil {
		return "", err
	}

	nonce := make([]byte, gcm.NonceSize())
	if _, err = rand.Read(nonce); err != nil {
		return "", err
	}
	sealed := gcm.Seal(nil, nonce, message, nil)

	out := make([]byte, 0, len(ephemeralPub)+len(nonce)+len(sealed))
	out = append(out, ephemeralPub...)
	out = append(out, nonce...)
	out = append(out, sealed...)
	return base64.StdEncoding.EncodeToString(out), nil
}

// Ed25519Decrypt reverses Ed25519Encrypt using the recipient's Ed25519
// private key.
//
// base64Packed is trimmed of surrounding whitespace and decoded from standard
// Base64; it must hold a 32-byte ephemeral X25519 public key, a 12-byte nonce
// and the AES-GCM ciphertext, in that order. The private key is converted with
// EdPrivToCurvePriv, the X25519 shared secret is hashed with SHA-256 to obtain
// the AES-256 key, and the ciphertext is opened with no additional data.
//
// An error is returned for malformed input, for key conversion failures, and
// when authentication of the ciphertext fails.
func Ed25519Decrypt(base64Packed string, recipientPriv ed25519.PrivateKey) ([]byte, error) {
	const (
		ephemeralPubLen = 32
		nonceLen        = 12
		headerLen       = ephemeralPubLen + nonceLen
	)

	packed, err := base64.StdEncoding.DecodeString(strings.TrimSpace(base64Packed))
	if err != nil {
		return nil, err
	}
	if len(packed) < headerLen {
		return nil, errors.New("invalid packed data")
	}
	ephemeralPub := packed[:ephemeralPubLen]
	nonce := packed[ephemeralPubLen:headerLen]
	sealed := packed[headerLen:]

	x25519 := ecdh.X25519()
	senderKey, err := x25519.NewPublicKey(ephemeralPub)
	if err != nil {
		return nil, err
	}

	scalar, err := EdPrivToCurvePriv(recipientPriv)
	if err != nil {
		return nil, err
	}
	recipientKey, err := x25519.NewPrivateKey(scalar[:])
	if err != nil {
		return nil, err
	}

	secret, err := recipientKey.ECDH(senderKey)
	if err != nil {
		return nil, err
	}

	// AES-256 key = SHA-256(shared secret).
	aesKey := sha256.Sum256(secret)
	blockCipher, err := aes.NewCipher(aesKey[:])
	if err != nil {
		return nil, err
	}
	gcm, err := cipher.NewGCM(blockCipher)
	if err != nil {
		return nil, err
	}
	return gcm.Open(nil, nonce, sealed, nil)
}
